# Copyright (c) OpenMMLab. All rights reserved.
import os
from logging import warning
from os import path as osp

import mmcv
import numpy as np
from lyft_dataset_sdk.lyftdataset import LyftDataset as Lyft
from mmdet3d.datasets import LyftDataset
from pyquaternion import Quaternion

from .nuscenes_converter import get_2d_boxes, get_available_scenes, obtain_sensor2top

# Detection class names of the Lyft dataset. The position of each name in
# this tuple is used as its integer category id when exporting COCO-style
# annotations, so the order must not change.
lyft_categories = (
    "car", "truck", "bus", "emergency_vehicle", "other_vehicle",
    "motorcycle", "bicycle",
    "pedestrian", "animal",
)


def create_lyft_infos(root_path, info_prefix, version="v1.01-train", max_sweeps=10):
    """Create info file of lyft dataset.

    Given the raw data, generate its related info file(s) in pkl format.
    For ``'v1.01-test'`` a single ``{info_prefix}_infos_test.pkl`` is
    written; for ``'v1.01-train'`` both ``{info_prefix}_infos_train.pkl``
    and ``{info_prefix}_infos_val.pkl`` are written. All files are placed
    directly under ``root_path``.

    Note:
        The scene split lists are read from ``data/lyft/train.txt``,
        ``data/lyft/val.txt`` and ``data/lyft/test.txt`` relative to the
        current working directory, not relative to ``root_path``.

    Args:
        root_path (str): Path of the data root.
        info_prefix (str): Prefix of the info file to be generated.
        version (str): Version of the data, one of 'v1.01-train' or
            'v1.01-test'. Default: 'v1.01-train'
        max_sweeps (int): Max number of sweeps.
            Default: 10

    Raises:
        AssertionError: If ``version`` is not a supported version
            (``ValueError`` instead when assertions are disabled by ``-O``).
    """
    available_vers = ["v1.01-train", "v1.01-test"]
    # Validate the version before paying for the (slow) Lyft database load.
    assert version in available_vers, (
        f"Unsupported Lyft version {version!r}; expected one of {available_vers}"
    )
    if version == "v1.01-train":
        train_scene_names = mmcv.list_from_file("data/lyft/train.txt")
        val_scene_names = mmcv.list_from_file("data/lyft/val.txt")
    elif version == "v1.01-test":
        train_scene_names = mmcv.list_from_file("data/lyft/test.txt")
        val_scene_names = []
    else:
        # Only reachable when assertions are stripped (python -O).
        raise ValueError(f"unknown version {version!r}")

    lyft = Lyft(
        data_path=osp.join(root_path, version),
        json_path=osp.join(root_path, version, version),
        verbose=True,
    )

    # Resolve scene names to scene tokens in O(1) per lookup. If a name
    # appears more than once, the first occurrence wins.
    name_to_token = {}
    for scene in get_available_scenes(lyft):
        name_to_token.setdefault(scene["name"], scene["token"])
    # Names listed in the split files but not returned by
    # get_available_scenes are dropped.
    train_scenes = {
        name_to_token[name] for name in train_scene_names if name in name_to_token
    }
    val_scenes = {
        name_to_token[name] for name in val_scene_names if name in name_to_token
    }

    test = "test" in version
    if test:
        print(f"test scene: {len(train_scenes)}")
    else:
        print(f"train scene: {len(train_scenes)}, val scene: {len(val_scenes)}")
    train_lyft_infos, val_lyft_infos = _fill_trainval_infos(
        lyft, train_scenes, val_scenes, test, max_sweeps=max_sweeps
    )

    metadata = dict(version=version)
    if test:
        print(f"test sample: {len(train_lyft_infos)}")
        test_data = dict(infos=train_lyft_infos, metadata=metadata)
        mmcv.dump(test_data, osp.join(root_path, f"{info_prefix}_infos_test.pkl"))
    else:
        print(
            f"train sample: {len(train_lyft_infos)}, "
            f"val sample: {len(val_lyft_infos)}"
        )
        train_data = dict(infos=train_lyft_infos, metadata=metadata)
        mmcv.dump(train_data, osp.join(root_path, f"{info_prefix}_infos_train.pkl"))
        val_data = dict(infos=val_lyft_infos, metadata=metadata)
        mmcv.dump(val_data, osp.join(root_path, f"{info_prefix}_infos_val.pkl"))


def _fill_trainval_infos(lyft, train_scenes, val_scenes, test=False, max_sweeps=10):
    """Generate the train/val infos from the raw data.

    Args:
        lyft (:obj:`LyftDataset`): Dataset class in the Lyft dataset.
        train_scenes (set[str]): Tokens of the training scenes.
        val_scenes (set[str]): Tokens of the validation scenes. This is not
            consulted directly: every sample whose scene token is not in
            ``train_scenes`` is put into the validation infos.
        test (bool): Whether use the test mode. In the test mode, no
            annotations can be accessed. Default: False.
        max_sweeps (int): Max number of sweeps. Default: 10.

    Returns:
        tuple[list[dict]]: Information of training set and
            validation set that will be saved to the info file.
    """
    train_lyft_infos = []
    val_lyft_infos = []

    # Camera channels whose info is stored for every key frame.
    camera_types = [
        "CAM_FRONT",
        "CAM_FRONT_RIGHT",
        "CAM_FRONT_LEFT",
        "CAM_BACK",
        "CAM_BACK_LEFT",
        "CAM_BACK_RIGHT",
    ]
    name_mapping = LyftDataset.NameMapping

    for sample in mmcv.track_iter_progress(lyft.sample):
        lidar_token = sample["data"]["LIDAR_TOP"]
        sd_rec = lyft.get("sample_data", lidar_token)
        cs_record = lyft.get("calibrated_sensor", sd_rec["calibrated_sensor_token"])
        pose_record = lyft.get("ego_pose", sd_rec["ego_pose_token"])
        abs_lidar_path, boxes, _ = lyft.get_sample_data(lidar_token)
        # nuScenes devkit returns more convenient relative paths while
        # lyft devkit returns absolute paths: strip the current working
        # directory prefix to get a path relative to it.
        lidar_path = str(abs_lidar_path).split(f"{os.getcwd()}/")[-1]

        mmcv.check_file_exist(lidar_path)

        info = {
            "lidar_path": lidar_path,
            "token": sample["token"],
            "sweeps": [],
            "cams": dict(),
            "lidar2ego_translation": cs_record["translation"],
            "lidar2ego_rotation": cs_record["rotation"],
            "ego2global_translation": pose_record["translation"],
            "ego2global_rotation": pose_record["rotation"],
            "timestamp": sample["timestamp"],
        }

        l2e_t = info["lidar2ego_translation"]
        e2g_t = info["ego2global_translation"]
        l2e_r_mat = Quaternion(info["lidar2ego_rotation"]).rotation_matrix
        e2g_r_mat = Quaternion(info["ego2global_rotation"]).rotation_matrix

        # obtain 6 image's information per frame
        for cam in camera_types:
            cam_token = sample["data"][cam]
            _, _, cam_intrinsic = lyft.get_sample_data(cam_token)
            cam_info = obtain_sensor2top(
                lyft, cam_token, l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, cam
            )
            cam_info.update(cam_intrinsic=cam_intrinsic)
            info["cams"][cam] = cam_info

        # obtain sweeps for a single key-frame: starting from the key frame's
        # lidar record, follow "prev" links and collect at most max_sweeps
        # earlier lidar frames (stops early when "prev" is empty).
        sweeps = []
        while len(sweeps) < max_sweeps and sd_rec["prev"] != "":
            sweeps.append(
                obtain_sensor2top(
                    lyft, sd_rec["prev"], l2e_t, l2e_r_mat, e2g_t, e2g_r_mat, "lidar"
                )
            )
            sd_rec = lyft.get("sample_data", sd_rec["prev"])
        info["sweeps"] = sweeps

        # obtain annotation
        if not test:
            annotations = [
                lyft.get("sample_annotation", token) for token in sample["anns"]
            ]
            locs = np.array([b.center for b in boxes]).reshape(-1, 3)
            dims = np.array([b.wlh for b in boxes]).reshape(-1, 3)
            rots = np.array([b.orientation.yaw_pitch_roll[0] for b in boxes]).reshape(
                -1, 1
            )

            # Rename classes via LyftDataset.NameMapping; names without a
            # mapping are kept unchanged.
            names = np.array(
                [
                    name_mapping[b.name] if b.name in name_mapping else b.name
                    for b in boxes
                ]
            )

            # we need to convert rot to SECOND format.
            gt_boxes = np.concatenate([locs, dims, -rots - np.pi / 2], axis=1)
            assert len(gt_boxes) == len(
                annotations
            ), f"{len(gt_boxes)}, {len(annotations)}"
            info["gt_boxes"] = gt_boxes
            info["gt_names"] = names
            info["num_lidar_pts"] = np.array([a["num_lidar_pts"] for a in annotations])
            info["num_radar_pts"] = np.array([a["num_radar_pts"] for a in annotations])

        if sample["scene_token"] in train_scenes:
            train_lyft_infos.append(info)
        else:
            val_lyft_infos.append(info)

    return train_lyft_infos, val_lyft_infos


def export_2d_annotation(root_path, info_path, version):
    """Export 2d annotation from the info file and raw data.

    Loads the infos stored in ``info_path``, collects 2D boxes for every
    camera image of every info via ``get_2d_boxes`` and writes them as a
    COCO-style json file next to the info file, named
    ``<info_path without extension>.coco.json``.

    Args:
        root_path (str): Root path of the raw data.
        info_path (str): Path of the info file.
        version (str): Dataset version.
    """
    # ``warning`` imported at module level is ``logging.warning`` (a
    # function), so ``warning.warn`` would raise AttributeError; use the
    # ``warnings`` module instead.
    import warnings

    warnings.warn(
        "DeprecationWarning: 2D annotations are not used on the "
        "Lyft dataset. The function export_2d_annotation will be "
        "deprecated.",
        stacklevel=2,
    )
    # get bbox annotations for camera
    camera_types = [
        "CAM_FRONT",
        "CAM_FRONT_RIGHT",
        "CAM_FRONT_LEFT",
        "CAM_BACK",
        "CAM_BACK_LEFT",
        "CAM_BACK_RIGHT",
    ]
    lyft_infos = mmcv.load(info_path)["infos"]
    lyft = Lyft(
        data_path=osp.join(root_path, version),
        json_path=osp.join(root_path, version, version),
        verbose=True,
    )
    # Category ids are the positions of the names in lyft_categories.
    cat2Ids = [
        dict(id=cat_id, name=cat_name)
        for cat_id, cat_name in enumerate(lyft_categories)
    ]
    coco_ann_id = 0
    coco_2d_dict = dict(annotations=[], images=[], categories=cat2Ids)
    for info in mmcv.track_iter_progress(lyft_infos):
        for cam in camera_types:
            cam_info = info["cams"][cam]
            coco_infos = get_2d_boxes(
                lyft,
                cam_info["sample_data_token"],
                visibilities=["", "1", "2", "3", "4"],
            )
            # The image is read only to obtain its size.
            (height, width, _) = mmcv.imread(cam_info["data_path"]).shape
            coco_2d_dict["images"].append(
                dict(
                    file_name=cam_info["data_path"],
                    id=cam_info["sample_data_token"],
                    width=width,
                    height=height,
                )
            )
            for coco_info in coco_infos:
                if coco_info is None:
                    continue
                # add an empty key for coco format
                coco_info["segmentation"] = []
                coco_info["id"] = coco_ann_id
                coco_2d_dict["annotations"].append(coco_info)
                coco_ann_id += 1
    # Strip the real extension (e.g. ".pkl") instead of a fixed 4 characters.
    mmcv.dump(coco_2d_dict, f"{osp.splitext(info_path)[0]}.coco.json")
